| from transformers import VitsModel , VitsConfig |
| from torch import nn |
| from torch.nn.utils.parametrizations import weight_norm |
| from safetensors.torch import load_file |
| import torch |
|
|
class ModVitsModel(VitsModel):
    """VITS model whose speaker embedding table is the cross product of speakers and emotions.

    Each (speaker, emotion) pair gets its own row in ``embed_speaker``: the row for
    speaker ``s`` and emotion ``e`` sits at index ``s * len(config.emotion_names) + e``.
    The config is read for ``speaker_names``, ``emotion_names``,
    ``undefined_emotion_index`` and ``speaker_embedding_size``.
    """

    def __init__(self, config: VitsConfig):
        """Build the model with a speaker table sized speakers x emotions.

        Note: ``config.num_speakers`` is overwritten in place on the config passed in.
        """
        config.num_speakers = len(config.emotion_names) * len(config.speaker_names)
        super().__init__(config)

    def init_weights(self):
        """Add weight-norm parametrizations to the decoder convolutions, then run the parent init.

        The parametrized parameter layout is presumably what the checkpoints store
        (TODO confirm against the checkpoint keys). Layers that are already parametrized
        are skipped, so a repeated call does not stack a second weight norm on the first.
        """
        from torch.nn.utils import parametrize

        def _apply_weight_norm(layers):
            # weight_norm mutates each layer in place, so the containers can stay as they are.
            for layer in layers:
                if not parametrize.is_parametrized(layer, "weight"):
                    weight_norm(layer)

        _apply_weight_norm(self.decoder.upsampler)
        for block in self.decoder.resblocks:
            _apply_weight_norm(block.convs1)
            _apply_weight_norm(block.convs2)
        return super().init_weights()

    @staticmethod
    def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
        """Fold separate speaker and emotion embeddings into one speaker table, then load.

        Weights are read from every file in ``checkpoint_files`` when any are given;
        otherwise a shallow copy of ``state_dict`` is used, so the caller's dict is not
        mutated. When ``embed_emotion.weight`` is present:

        - the first ``len(config.speaker_names)`` rows of ``embed_speaker.weight`` are kept;
        - the first ``len(config.emotion_names)`` rows of ``embed_emotion.weight`` whose
          index is not in ``config.undefined_emotion_index`` are kept;
        - ``embed_speaker.weight`` is replaced by every speaker row plus every emotion row
          (speaker-major order), and ``embed_emotion.weight`` is removed.

        A state dict without ``embed_emotion.weight`` (e.g. one already merged) is passed
        through unchanged.

        Raises:
            ValueError: if the checkpoint has fewer speaker or defined emotion rows than
                the config lists.
        """
        if checkpoint_files:
            # Read every shard so the embedding tables are found wherever they were stored.
            merged = {}
            for path in checkpoint_files:
                merged.update(load_file(path))
            state_dict = merged
        elif state_dict is not None:
            state_dict = dict(state_dict)

        if state_dict is not None and "embed_emotion.weight" in state_dict:
            config = model.config
            num_speakers = len(config.speaker_names)
            num_emotions = len(config.emotion_names)
            undefined = set(config.undefined_emotion_index)

            speakers = state_dict["embed_speaker.weight"][:num_speakers]
            if speakers.shape[0] < num_speakers:
                raise ValueError(
                    f"checkpoint has {speakers.shape[0]} speaker embeddings, "
                    f"config lists {num_speakers} speakers"
                )

            emotion_table = state_dict.pop("embed_emotion.weight")
            defined_rows = [
                row for row in range(emotion_table.shape[0]) if row not in undefined
            ][:num_emotions]
            if len(defined_rows) < num_emotions:
                raise ValueError(
                    f"checkpoint has {len(defined_rows)} defined emotion embeddings, "
                    f"config lists {num_emotions} emotions"
                )
            emotions = emotion_table[defined_rows]

            # Broadcast add: row s * num_emotions + e == speakers[s] + emotions[e].
            combined = speakers[:, None, :] + emotions[None, :, :]
            state_dict["embed_speaker.weight"] = combined.reshape(-1, config.speaker_embedding_size)

        # Explicit form of the zero-argument super() the staticmethod relied on.
        return super(ModVitsModel, model)._load_pretrained_model(
            model, state_dict, checkpoint_files, load_config
        )

    @torch.inference_mode()
    def forward(self, input_ids=None, attention_mask=None, speaker_id=None, output_attentions=None, output_hidden_states=None, return_dict=None, labels=None, **kwargs):
        """Synthesize speech for ``speaker_id`` speaking in emotion ``style_id``.

        Runs under ``torch.inference_mode()``, so no autograd graph is recorded.

        Args:
            speaker_id: index into ``config.speaker_names`` (int or tensor).
            style_id (keyword, required): index into ``config.emotion_names``.
            Remaining arguments are forwarded to ``VitsModel.forward`` by keyword.

        Returns:
            The parent's output, with every waveform sample at or past its
            ``sequence_lengths`` entry set to 0. A tuple is returned when
            ``return_dict`` resolves to False.

        Raises:
            ValueError: if ``speaker_id`` or ``style_id`` is missing, or if ``style_id``
                is outside ``[0, len(config.emotion_names))``.
        """
        # Pop so the extra keyword is not handed on to the parent forward.
        style_id = kwargs.pop("style_id", None)
        if speaker_id is None or style_id is None:
            raise ValueError("ModVitsModel.forward requires both `speaker_id` and `style_id`.")

        num_emotions = len(self.config.emotion_names)
        styles = torch.as_tensor(style_id)
        # The parent only range-checks the combined id; an out-of-range style would
        # otherwise silently select another speaker's row.
        if bool(((styles < 0) | (styles >= num_emotions)).any()):
            raise ValueError(f"`style_id` must be in [0, {num_emotions}), got {style_id}.")
        combined_id = speaker_id * num_emotions + style_id

        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)
        # Always request a ModelOutput: the masking below needs the named fields.
        audio = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            speaker_id=combined_id,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            labels=labels,
            **kwargs,
        )

        waveform = audio.waveform
        _, num_samples = waveform.shape
        positions = torch.arange(num_samples, device=waveform.device)
        # Zero every sample past each batch item's sequence length.
        padding = positions.unsqueeze(0) >= audio.sequence_lengths.unsqueeze(1)
        waveform.masked_fill_(padding, 0)
        return audio if return_dict else audio.to_tuple()
| |