Spaces:
Running on Zero
Running on Zero
| import math | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
def get_sample_indices(original_fps, total_frames, target_fps, num_sample, fixed_start=None):
    """Choose source-frame indices that resample a clip to ``target_fps``.

    A window lasting ``num_sample / target_fps`` seconds is placed on a
    source sequence of ``total_frames`` frames running at ``original_fps``.
    ``num_sample`` evenly spaced time points (window end excluded) are then
    mapped back to the nearest source frame.

    Args:
        original_fps: Frame rate of the source sequence.
        total_frames: Number of frames available in the source sequence.
        target_fps: Frame rate the returned samples should correspond to.
        num_sample: Number of indices to return.
        fixed_start: Source frame at which the window starts. When ``None``
            or negative, the start is drawn uniformly at random from the
            positions where the whole window still fits.

    Returns:
        Integer ``numpy.ndarray`` of length ``num_sample``; every entry is
        clipped into ``[0, total_frames - 1]``.

    Raises:
        ValueError: If the window is longer than the source sequence.
    """
    window_seconds = num_sample / target_fps
    window_frames = int(np.ceil(window_seconds * original_fps))
    if window_seconds > total_frames / original_fps:
        raise ValueError("required_duration must be less than video length")

    if fixed_start is not None and fixed_start >= 0:
        first_frame = fixed_start
    else:
        latest_start = total_frames - window_frames
        if latest_start < 0:
            raise ValueError("video length is too short")
        # randint's upper bound is exclusive, hence the +1.
        first_frame = np.random.randint(0, latest_start + 1)

    t_begin = first_frame / original_fps
    sample_times = np.linspace(t_begin, t_begin + window_seconds, num_sample, endpoint=False)
    nearest = np.round(sample_times * original_fps).astype(int)
    # A window that runs past the end repeats the last frame instead of
    # indexing out of range.
    return np.clip(nearest, 0, total_frames - 1)
def linear_interpolation(features, input_fps, output_fps, output_len=None):
    """Linearly resample a feature sequence along its time axis.

    Args:
        features: Tensor laid out as ``[B, T, C]``; axis 1 is interpolated.
        input_fps: Frame rate of ``features`` (the audio feature rate).
        output_fps: Frame rate to resample to (the video rate).
        output_len: Number of output frames. When ``None`` it is derived as
            ``int(T / input_fps * output_fps)``.

    Returns:
        Tensor of shape ``[B, output_len, C]``.
    """
    # F.interpolate works on the last axis, so time has to be moved there.
    channels_first = features.transpose(1, 2)
    duration = channels_first.shape[2] / float(input_fps)
    target_len = int(duration * output_fps) if output_len is None else output_len
    resampled = F.interpolate(
        channels_first, size=target_len, mode='linear', align_corners=True
    )  # [B, C, target_len]
    return resampled.transpose(1, 2)
class WanS2VAudioEncoder(torch.nn.Module):
    """Wav2Vec2-based audio feature extractor that groups features per video frame.

    The wrapped ``Wav2Vec2ForCTC`` model is built from an inline config, so
    it starts with freshly initialised weights; loading a checkpoint is left
    to the caller. Features are resampled to ``self.video_rate`` frames per
    second and then gathered into fixed-size buckets of frames.
    """

    def __init__(self):
        super().__init__()
        # Imported lazily so the module can be imported without transformers.
        from transformers import Wav2Vec2ForCTC, Wav2Vec2Config
        config = {
            "_name_or_path": "facebook/wav2vec2-large-xlsr-53",
            "activation_dropout": 0.05,
            "apply_spec_augment": True,
            "architectures": ["Wav2Vec2ForCTC"],
            "attention_dropout": 0.1,
            "bos_token_id": 1,
            "conv_bias": True,
            "conv_dim": [512, 512, 512, 512, 512, 512, 512],
            "conv_kernel": [10, 3, 3, 3, 3, 2, 2],
            "conv_stride": [5, 2, 2, 2, 2, 2, 2],
            "ctc_loss_reduction": "mean",
            "ctc_zero_infinity": True,
            "do_stable_layer_norm": True,
            "eos_token_id": 2,
            "feat_extract_activation": "gelu",
            "feat_extract_dropout": 0.0,
            "feat_extract_norm": "layer",
            "feat_proj_dropout": 0.05,
            "final_dropout": 0.0,
            "hidden_act": "gelu",
            "hidden_dropout": 0.05,
            "hidden_size": 1024,
            "initializer_range": 0.02,
            "intermediate_size": 4096,
            "layer_norm_eps": 1e-05,
            "layerdrop": 0.05,
            "mask_channel_length": 10,
            "mask_channel_min_space": 1,
            "mask_channel_other": 0.0,
            "mask_channel_prob": 0.0,
            "mask_channel_selection": "static",
            "mask_feature_length": 10,
            "mask_feature_prob": 0.0,
            "mask_time_length": 10,
            "mask_time_min_space": 1,
            "mask_time_other": 0.0,
            "mask_time_prob": 0.05,
            "mask_time_selection": "static",
            "model_type": "wav2vec2",
            "num_attention_heads": 16,
            "num_conv_pos_embedding_groups": 16,
            "num_conv_pos_embeddings": 128,
            "num_feat_extract_layers": 7,
            "num_hidden_layers": 24,
            "pad_token_id": 0,
            "transformers_version": "4.7.0.dev0",
            "vocab_size": 33
        }
        self.model = Wav2Vec2ForCTC(Wav2Vec2Config(**config))
        # Frame rate (frames per second) the audio features are resampled to.
        self.video_rate = 30

    def extract_audio_feat(self, input_audio, sample_rate, processor, return_all_layers=False, dtype=torch.float32, device='cpu'):
        """Run the Wav2Vec2 model and resample its hidden states to ``self.video_rate``.

        Args:
            input_audio: Raw audio, passed straight to ``processor``.
            sample_rate: Sampling rate forwarded to ``processor``.
            processor: Callable whose result exposes ``.input_values``
                (presumably a Hugging Face Wav2Vec2 processor; confirm
                against the caller).
            return_all_layers: When true, all returned hidden states are
                concatenated along dim 0; otherwise only the last is used.
            dtype: dtype the input values are cast to.
            device: Device the input values are moved to.

        Returns:
            Tensor ``[N, T', C]`` where ``T'`` is the length after resampling
            from 50 fps to ``self.video_rate``. ``N`` is the number of
            concatenated hidden states (assuming each carries a batch of one)
            when ``return_all_layers`` is true, else the batch size.
        """
        input_values = processor(
            input_audio, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(dtype=dtype, device=device)

        # Only the hidden states are used; the CTC logits are ignored.
        res = self.model(input_values, output_hidden_states=True)
        if return_all_layers:
            feat = torch.cat(res.hidden_states)
        else:
            feat = res.hidden_states[-1]
        # 50 fps is the feature rate assumed for the model output (the conv
        # strides above multiply to 320 samples per frame, i.e. 50 fps for
        # 16 kHz audio - NOTE(review): confirm the input really is 16 kHz).
        feat = linear_interpolation(feat, input_fps=50, output_fps=self.video_rate)
        return feat

    @staticmethod
    def _gather_windows(audio_embed, frame_indices, window_stride, m):
        """Stack one context window of features per entry of ``frame_indices``.

        Each window holds the ``2 * m + 1`` frames ``center + k * window_stride``
        for ``k`` in ``[-m, m]``, clamped to the valid range and flattened
        together with the feature axis. A center at or beyond the end of the
        audio yields an all-zero row.

        Returns:
            ``[len(frame_indices), num_layers, audio_dim * (2 * m + 1)]`` when
            ``audio_embed`` has more than one layer, otherwise
            ``[len(frame_indices), audio_dim * (2 * m + 1)]``.
        """
        num_layers, audio_frame_num, audio_dim = audio_embed.shape
        all_layers = num_layers > 1
        flat_dim = audio_dim * (2 * m + 1)
        pad_shape = [num_layers, flat_dim] if all_layers else [flat_dim]
        last_frame = audio_frame_num - 1

        rows = []
        for center in frame_indices:
            center = int(center)  # may arrive as a numpy integer
            if center < audio_frame_num:
                window = [
                    min(max(center + k * window_stride, 0), last_frame)
                    for k in range(-m, m + 1)
                ]
                if all_layers:
                    row = audio_embed[:, window].flatten(start_dim=-2, end_dim=-1)
                else:
                    row = audio_embed[0][window].flatten()
            else:
                # Match the input dtype so padding does not silently promote
                # a half-precision bucket to float32.
                row = torch.zeros(pad_shape, dtype=audio_embed.dtype, device=audio_embed.device)
            rows.append(row)
        return torch.stack(rows, dim=0)

    def get_audio_embed_bucket(self, audio_embed, stride=2, batch_frames=12, m=2):
        """Bucket features by taking every ``stride``-th audio frame.

        Args:
            audio_embed: Tensor ``[num_layers, audio_frame_num, audio_dim]``.
            stride: Spacing between consecutive window centers.
            batch_frames: Number of frames per batch.
            m: Context radius; each window spans ``2 * m + 1`` frames.

        Returns:
            ``(bucket, min_batch_num)``. ``bucket`` has
            ``min_batch_num * batch_frames`` rows (see ``_gather_windows``
            for the remaining dims); rows past the end of the audio are zero.
        """
        _, audio_frame_num, _ = audio_embed.shape
        min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
        bucket_num = min_batch_num * batch_frames
        centers = [stride * i for i in range(bucket_num)]
        # The step inside a window is fixed at 2 and does not follow
        # ``stride``; kept exactly as the original behaviour.
        bucket = self._gather_windows(audio_embed, centers, window_stride=2, m=m)
        return bucket, min_batch_num

    def get_audio_embed_bucket_fps(self, audio_embed, fps=16, batch_frames=81, m=0):
        """Bucket features by resampling from ``self.video_rate`` to ``fps``.

        Args:
            audio_embed: Tensor ``[num_layers, audio_frame_num, audio_dim]``
                sampled at ``self.video_rate``.
            fps: Target frame rate of the buckets.
            batch_frames: Number of frames per batch.
            m: Context radius; each window spans ``2 * m + 1`` frames.

        Returns:
            ``(bucket, min_batch_num)``. ``bucket`` has
            ``min_batch_num * batch_frames`` rows (see ``_gather_windows``
            for the remaining dims); rows past the end of the audio are zero.
        """
        _, audio_frame_num, _ = audio_embed.shape
        scale = self.video_rate / fps
        min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
        bucket_num = min_batch_num * batch_frames
        # Length the sequence would have if the audio were padded to fill
        # every bucket; indices beyond the real audio become zero rows.
        padded_frame_num = math.ceil(bucket_num / fps * self.video_rate)
        centers = get_sample_indices(
            original_fps=self.video_rate,
            total_frames=padded_frame_num,
            target_fps=fps,
            num_sample=bucket_num,
            fixed_start=0,
        )
        # Clamp to 1: for fps > video_rate the integer ratio is 0, which
        # would be an invalid (zero) step between window entries.
        window_stride = max(1, int(self.video_rate / fps))
        bucket = self._gather_windows(audio_embed, centers, window_stride=window_stride, m=m)
        return bucket, min_batch_num

    def get_audio_feats_per_inference(self, input_audio, sample_rate, processor, fps=16, batch_frames=80, m=0, dtype=torch.float32, device='cpu'):
        """Extract all-layer features and split them into per-inference chunks.

        Args:
            input_audio: Raw audio, passed to ``extract_audio_feat``.
            sample_rate: Sampling rate forwarded to ``processor``.
            processor: Feature processor forwarded to ``extract_audio_feat``.
            fps: Target frame rate of the buckets.
            batch_frames: Number of frames in each returned chunk.
            m: Context radius forwarded to ``get_audio_embed_bucket_fps``.
            dtype: dtype of the model input and of the returned tensors.
            device: Device of the model input and of the returned tensors.

        Returns:
            List of ``min_batch_num`` tensors, each
            ``[1, num_layers, audio_dim * (2 * m + 1), batch_frames]``.
        """
        audio_feat = self.extract_audio_feat(
            input_audio, sample_rate, processor,
            return_all_layers=True, dtype=dtype, device=device,
        )
        audio_embed_bucket, min_batch_num = self.get_audio_embed_bucket_fps(
            audio_feat, fps=fps, batch_frames=batch_frames, m=m
        )
        # [frames, layers, dim] -> [1, layers, dim, frames]
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
        return [
            audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames]
            for i in range(min_batch_num)
        ]