from einops import rearrange, repeat
import torch
import torch.nn as nn
from ..wanvideo.modules.attention import attention


def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps

    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t

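# Rough sanity check for timestep_transform (with the defaults shift=5.0,
# num_timesteps=1000): the midpoint t=500 maps to 5*0.5/(1+4*0.5)*1000 ≈ 833,
# i.e. the warped schedule spends more of the trajectory at high-noise timesteps.
# >>> timestep_transform(torch.tensor([500.0]))
# tensor([833.3333])
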
|
def add_noise(
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """
    Linear interpolation between clean samples and noise, compatible with
    diffusers add_noise().
    """
    timesteps = timesteps.float() / 1000
    # Broadcast the per-sample timesteps to the rank of the noise tensor.
    timesteps = timesteps.view(timesteps.shape + (1,) * (len(noise.shape) - 1))

    return (1 - timesteps) * original_samples + timesteps * noise

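# Example (shapes only; the values are illustrative): with latents of shape
# [B, C, T, H, W] and per-sample timesteps of shape [B], the timesteps are
# broadcast to [B, 1, 1, 1, 1] before interpolating.
# >>> latents = torch.randn(2, 16, 4, 8, 8)
# >>> noisy = add_noise(latents, torch.randn_like(latents), torch.tensor([100, 900]))
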
|
def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    # Min-max normalize `column` from source_range, then rescale it into target_range.
    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled

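# e.g. normalize_and_scale(torch.tensor([0.0, 5.0, 10.0]), (0.0, 10.0), (0, 4))
# gives roughly tensor([0., 2., 4.]) (up to the epsilon that guards against
# division by zero when the source range collapses).
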
|
def rotate_half(x):
    # RoPE helper: rotate channel pairs (x1, x2) -> (-x2, x1) along the last dimension.
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")

|
def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, mode='mean', attn_bias=None):
    # visual_q: [B, x_seqlens, H, K] video queries; ref_k: [B, ref_seqlens, H, K] reference keys.
    ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q * scale
    visual_q = visual_q.transpose(1, 2)
    ref_k = ref_k.transpose(1, 2)
    attn = visual_q @ ref_k.transpose(-2, -1)

    if attn_bias is not None:
        attn = attn + attn_bias

    x_ref_attn_map_source = attn.softmax(-1)  # [B, H, x_seqlens, ref_seqlens]

    x_ref_attn_maps = []
    ref_target_masks = ref_target_masks.to(visual_q.dtype)
    x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        # Average the attention mass that falls inside this speaker's reference mask.
        ref_target_mask = ref_target_mask[None, None, None, ...]
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum()  # [B, H, x_seqlens]
        x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1)  # [B, x_seqlens, H]

        if mode == 'mean':
            x_ref_attnmap = x_ref_attnmap.mean(-1)  # [B, x_seqlens]
        elif mode == 'max':
            # Tensor.max(dim) returns (values, indices); keep only the values.
            x_ref_attnmap = x_ref_attnmap.max(-1).values  # [B, x_seqlens]

        x_ref_attn_maps.append(x_ref_attnmap)

    del attn, x_ref_attn_map_source

    return torch.concat(x_ref_attn_maps, dim=0)

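# Shape sketch for calculate_x_ref_attn_map (assuming visual_q is [B, x_seqlens, H, K]
# and ref_k is [B, ref_seqlens, H, K]): the scaled dot product forms a map of shape
# [B, H, x_seqlens, ref_seqlens]; masking over the reference tokens and reducing over
# heads gives [B, x_seqlens] per speaker, and the speakers are concatenated along dim 0.
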
|
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
    """Args:
        visual_q (torch.Tensor): [B, M, H, K] video-token queries
        ref_k (torch.Tensor): [B, M, H, K] keys; only the first N_h * N_w reference tokens are used
        shape (tuple): (N_t, N_h, N_w) latent video grid
        ref_target_masks (torch.Tensor): [class_num, N_h * N_w] per-speaker reference masks
        split_num (int): number of head chunks processed sequentially to limit peak memory
    """
    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    # Accumulate the per-speaker maps over head chunks, then average.
    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_k[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_target_masks,
        )
        x_ref_attn_maps += x_ref_attn_maps_perhead

    return x_ref_attn_maps / split_num

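# Note: splitting over heads only bounds the memory of the intermediate
# [B, H_chunk, x_seqlens, ref_seqlens] attention map. Because each chunk's map is
# already a mean over its own heads and the chunks are equal-sized (assuming heads
# is divisible by split_num), averaging the chunk results equals the all-head mean.
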
|
class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self,
                 head_dim,
                 ):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    def precompute_freqs_cis_1d(self, pos_indices):
        # Standard RoPE frequencies base**(-2i / head_dim), one per channel pair,
        # evaluated at the (possibly fractional) positions.
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            x (torch.Tensor): [B, head, seq, head_dim]
            pos_indices (torch.Tensor): [seq,] rotary position of every token
        Returns:
            Tensor with the same shape and dtype as the input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)

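# Minimal usage sketch (shapes are illustrative): rotate a [B, head, seq, head_dim]
# tensor with fractional per-token positions, as done for the speaker-dependent
# query positions in SingleStreamMultiAttention below.
# >>> rope = RotaryPositionalEmbedding1D(head_dim=64)
# >>> q = torch.randn(1, 2, 6, 64)
# >>> rope(q, torch.tensor([0.0, 2.0, 2.0, 12.0, 22.0, 12.0])).shape
# torch.Size([1, 2, 6, 64])
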
|
class AudioProjModel(nn.Module):
    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # Two input projections (the two audio streams use different window lengths),
        # followed by a shared MLP that expands each frame into context_tokens tokens.
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        """audio_embeds: [B, F1, seq_len, blocks, channels];
        audio_embeds_vf: [B, F2, seq_len_vf, blocks, channels].
        Returns context tokens of shape [B, F1 + F2, context_tokens, output_dim].
        """
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B, _, _, S, C = audio_embeds.shape

        # Flatten each per-frame window (window x blocks x channels) into a single vector.
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)

        # Project both streams, then concatenate them along the frame axis.
        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
        batch_size_c, N_t, C_a = audio_embeds_c.shape
        audio_embeds_c = audio_embeds_c.view(batch_size_c * N_t, C_a)

        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))

        # Expand each per-frame embedding into a fixed number of context tokens.
        context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c * N_t, self.context_tokens, self.output_dim)

        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return context_tokens

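# Quick shape check with the default configuration (seq_len=5, seq_len_vf=12,
# blocks=12, channels=768, context_tokens=32, output_dim=768); the frame counts
# below are purely illustrative:
# >>> proj = AudioProjModel()
# >>> a = torch.randn(1, 1, 5, 12, 768)       # stream consumed by proj1
# >>> a_vf = torch.randn(1, 20, 12, 12, 768)  # stream consumed by proj1_vf
# >>> proj(a, a_vf).shape
# torch.Size([1, 21, 32, 768])
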
|
class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        attention_mode: str = 'sdpa',
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attention_mode = attention_mode

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        # If the sequence is longer than the N_t * N_h * N_w video grid, split off the
        # trailing N_h * N_w tokens and exclude them from the audio cross-attention.
        expected_tokens = N_t * N_h * N_w
        actual_tokens = x.shape[1]
        x_extra = None

        if actual_tokens != expected_tokens:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1

        B = x.shape[0]
        S = N_h * N_w
        x = x.view(B * N_t, S, self.dim)

        # Queries come from the video tokens, one batch entry per frame.
        q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)

        # Keys/values come from the per-frame audio context tokens.
        kv = self.kv_linear(encoder_hidden_states)
        encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)

        x = attention(q, encoder_k, encoder_v, attention_mode=self.attention_mode)

        # Output projection, then restore the [B, N_t * S, C] layout.
        x = self.proj(x.reshape(B * N_t, S, self.dim))
        x = x.view(B, N_t * S, self.dim)

        # The split-off tokens receive a zero cross-attention contribution.
        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x

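# Usage sketch (dimensions are illustrative, not taken from any model config):
# a 4-frame 8x8 latent grid cross-attending to 32 audio context tokens per frame.
# >>> attn = SingleStreamAttention(dim=128, encoder_hidden_states_dim=768, num_heads=8, qkv_bias=True)
# >>> x = torch.randn(1, 4 * 8 * 8, 128)
# >>> audio = torch.randn(4, 32, 768)  # one batch entry of context tokens per frame
# >>> attn(x, audio, shape=(4, 8, 8)).shape
# torch.Size([1, 256, 128])
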
|
class SingleStreamMultiAttention(SingleStreamAttention):
    """Multi-speaker rotary-position cross-attention.

    This implementation generalises the original 2-speaker logic to an arbitrary
    number of voices. Each speaker is allocated a contiguous *class_interval*-wide
    slot inside the shared *class_range* rotary range. The centre of each slot is
    applied to that speaker's audio KV tokens, while queries are modulated per
    token according to which speaker dominates the pixel.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        class_range: int = 24,
        class_interval: int = 4,
        attention_mode: str = 'sdpa',
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attention_mode=attention_mode,
        )

        # Each speaker occupies one class_interval-wide slot inside class_range.
        self.class_interval = class_interval
        self.class_range = class_range
        self.max_humans = self.class_range // self.class_interval

        # Rotary position assigned to background (non-speaker) tokens.
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

        self.attention_mode = attention_mode

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        shape=None,
        x_ref_attn_map=None,
        human_num=None,
    ) -> torch.Tensor:
        encoder_hidden_states = encoder_hidden_states.squeeze(0)

        # Single-speaker (or unspecified) case: plain cross-attention is enough.
        if human_num is None or human_num <= 1:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, N_h, N_w = shape

        # Split off trailing tokens that do not belong to the N_t * N_h * N_w video grid.
        x_extra = None
        if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # Queries from the video tokens.
        B, N, C = x.shape
        q = self.q_linear(x)
        q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if human_num == 2:
            # Original 2-speaker layout: speaker 1 at the bottom of the range,
            # speaker 2 at the top, background in the middle.
            rope_h1 = (0, self.class_interval)
            rope_h2 = (self.class_range - self.class_interval, self.class_range)
            rope_bak = int(self.class_range // 2)

            # Per-speaker extrema of the reference attention map, used for normalization.
            max_values = x_ref_attn_map.max(1).values[:, None, None]
            min_values = x_ref_attn_map.min(1).values[:, None, None]
            max_min_values = torch.cat([max_values, min_values], dim=2)

            human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
            human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

            human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
            human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
            back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)

            # Each token takes the position of whichever speaker dominates it.
            max_indices = x_ref_attn_map.argmax(dim=0)
            normalized_map = torch.stack([human1, human2, back], dim=1)
            normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
        else:
            # Generalised layout: one contiguous slot per speaker.
            rope_ranges = [
                (i * self.class_interval, (i + 1) * self.class_interval)
                for i in range(human_num)
            ]

            # Normalize each speaker's attention map into its own slot.
            human_norm_list = []
            for idx in range(human_num):
                attn_map = x_ref_attn_map[idx]
                att_min, att_max = attn_map.min(), attn_map.max()
                human_norm = normalize_and_scale(
                    attn_map, (att_min, att_max), rope_ranges[idx]
                )
                human_norm_list.append(human_norm)

            # Background tokens sit at the centre of the whole range.
            back = torch.full(
                (x_ref_attn_map.size(1),),
                self.rope_bak,
                dtype=x_ref_attn_map.dtype,
                device=x_ref_attn_map.device,
            )

            # Each token takes the position of whichever speaker dominates it.
            max_indices = x_ref_attn_map.argmax(dim=0)
            normalized_map = torch.stack(human_norm_list + [back], dim=1)
            normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]

        # Apply 1D RoPE to the queries with the per-token speaker positions.
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Keys/values from the audio context tokens.
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        encoder_k, encoder_v = encoder_kv.unbind(0)

        # Rotate each speaker's audio keys to the centre of that speaker's slot.
        if human_num == 2:
            per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
            per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
            per_frame[per_frame.size(0) // 2:] = (rope_h2[0] + rope_h2[1]) / 2
            encoder_pos = torch.cat([per_frame] * N_t, dim=0)
        else:
            tokens_per_human = N_a // human_num
            encoder_pos_list = []
            for i in range(human_num):
                start, end = rope_ranges[i]
                centre = (start + end) / 2
                encoder_pos_list.append(
                    torch.full(
                        (tokens_per_human,), centre, dtype=encoder_k.dtype, device=encoder_k.device
                    )
                )
            encoder_pos = torch.cat(encoder_pos_list * N_t, dim=0)

        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # attention() expects the [B, M, H, K] layout.
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        x = attention(
            q, encoder_k, encoder_v, attention_mode=self.attention_mode
        )

        # Output projection.
        x = x.reshape(B, N, C)
        x = self.proj(x)

        # Restore the [B, N_t * S, C] layout; extra tokens get a zero contribution.
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x
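# Positional layout sketch for the multi-speaker path (with the defaults
# class_range=24, class_interval=4): for human_num=2, query positions are squeezed
# into [0, 4] for speaker 1 and [20, 24] for speaker 2, background pixels sit at 12,
# and each speaker's audio keys are rotated to the centre of their slot (2 and 22),
# so each region of the frame attends most strongly to "its" speaker's audio tokens.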