from typing import Sequence, Optional, Union
import sys
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoFeatureExtractor, WhisperModel

from modules.seanet import SEANetEncoder, SEANetDecoder
from quantization import ResidualVectorQuantizer
from RepCodec.repcodec.modules.encoder import Encoder
from RepCodec.repcodec.modules.decoder import Decoder
import descriptaudiocodec.dac.model.dac as dac2


def get_model_size(model):
    # Total parameter count and approximate size in MB, assuming float32
    # parameters (4 bytes each).
    total_params = sum(p.numel() for p in model.parameters())
    model_size_bytes = total_params * 4
    model_size_mb = model_size_bytes / (1024 ** 2)
    return total_params, model_size_mb


class SoundStream(nn.Module):
    """SoundStream / EnCodec-style neural audio codec.

    Args:
        n_filters (int): Base width for the model.
        D (int): Intermediate representation dimension.
        target_bandwidths (Sequence[Union[int, float]]): Target bandwidths in kbps.
        ratios (Sequence[int]): Downsampling factors; their product is the hop size.
        sample_rate (int): Waveform sampling rate.
        bins (int): Number of code words per codebook.
        normalize (bool): Whether to normalize the audio.
    """
    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],
        sample_rate: int = 16000,
        bins: int = 1024,
        normalize: bool = False,
        causal: bool = False,
    ):
        super().__init__()
        self.hop_length = np.prod(ratios)
        # Number of quantizers needed to hit the largest target bandwidth.
        n_q = int(1000 * target_bandwidths[-1] // (math.ceil(sample_rate / self.hop_length) * 10))
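        # With the defaults: hop = 8 * 5 * 4 * 2 = 320 samples, i.e. 50 frames/s
        # at 16 kHz; each codebook carries log2(1024) = 10 bits per frame
        # (500 bps), so the top bandwidth of 6 kbps needs 6000 // 500 = 12 quantizers.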
        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))
        self.bits_per_codebook = int(math.log2(bins))
        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate

        # Acoustic branch: DAC-style convolutional encoder/decoder.
        self.encoder = dac2.Encoder(64, ratios, D)
        # Semantic branch: RepCodec encoder/decoder over 768-dim teacher features.
        self.encoder_semantic = Encoder(input_channels=768, encode_channels=768)
        self.decoder_semantic = Decoder(code_dim=768, output_channels=768, decode_channels=768)
        # A single residual VQ spans the concatenated acoustic + semantic features.
        self.quantizer = ResidualVectorQuantizer(dimension=D + 768, n_q=n_q, bins=bins)
        self.decoder_2 = dac2.Decoder(D, 1024, ratios)

        self.is_semantic = True
        if self.is_semantic:
            # Frozen self-supervised speech model that provides the semantic targets.
            self.semantic_model = AutoModel.from_pretrained("./xcodec_mini_infer/semantic_ckpts/hf_1_325000")
            self.semantic_model.eval()

        # Projections into and out of the joint (acoustic + semantic) space.
        self.fc_prior = nn.Linear(D + 768, D + 768)
        self.fc_post1 = nn.Linear(D + 768, 768)  # joint -> semantic
        self.fc_post2 = nn.Linear(D + 768, D)    # joint -> acoustic

    def get_last_layer(self):
        # Weight of the decoder's final layer, e.g. for adaptive loss weighting.
        # Assumes the decoder exposes a `layers` attribute.
        return self.decoder_2.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        # Cosine-distance loss between L2-normalized feature vectors.
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()
        return rec_loss

    @torch.no_grad()
    def get_regress_target(self, x):
        # x: (B, 1, T) waveform; the teacher consumes (B, T).
        x = x[:, 0, :]
        x = F.pad(x, (160, 160))
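        # The 160-sample padding on each side assumes the teacher uses a
        # 320-sample hop (20 ms at 16 kHz), adding one frame so the teacher's
        # output length lines up with the codec's downsampled feature length.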
        target = self.semantic_model(x, output_hidden_states=True).hidden_states
        # Average the teacher's hidden states across all layers.
        target = torch.stack(target, dim=1)
        target = target.mean(1)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        # Semantic regression target from the frozen teacher (no gradients).
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        # Concatenate acoustic and semantic features along the channel axis,
        # then project into the joint space before quantization.
        e = torch.cat([e_acoustic, e_semantic], dim=1)
        e = self.fc_prior(e.transpose(1, 2)).transpose(1, 2)

        quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)

        # Split the quantized representation back into the two branches.
        quantized_semantic = self.fc_post1(quantized.transpose(1, 2)).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized.transpose(1, 2)).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)

        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)

        return o, commit_loss, semantic_recon_loss, None

    def encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        # The teacher's padding can leave the two branches one frame apart;
        # re-encode the padded waveform so the sequence lengths match.
        if e_acoustic.shape[2] != e_semantic.shape[2]:
            e_acoustic = self.encoder(torch.transpose(F.pad(x[:, 0, :], (160, 160)).unsqueeze(0), 0, 1))

        e = torch.cat([e_acoustic, e_semantic], dim=1)
        e = self.fc_prior(e.transpose(1, 2)).transpose(1, 2)

        quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
        return codes

    def get_embed(self, codes: torch.Tensor) -> torch.Tensor:
        return self.quantizer.decode(codes)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        quantized = self.quantizer.decode(codes)
        quantized_acoustic = self.fc_post2(quantized.transpose(1, 2)).transpose(1, 2)
        o = self.decoder_2(quantized_acoustic)
        return o


if __name__ == '__main__':
    soundstream = SoundStream(n_filters=32, D=256)
    for i in range(10):
        print(f"Iter {i}: ")
        x = torch.rand(1, 1, 16000)  # (batch, channels, samples): 1 s at 16 kHz
        o, commit_loss, distill_loss, _ = soundstream(x, soundstream.target_bandwidths[-1])
        print('output', o.shape)
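
    # Round-trip sketch: encode the last batch to discrete codes and decode it
    # back. Assumes the quantizer accepts the same bandwidth used above.
    with torch.no_grad():
        codes = soundstream.encode(x, target_bw=soundstream.target_bandwidths[-1])
        recon = soundstream.decode(codes)
        print('codes', codes.shape, 'recon', recon.shape)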