| import torch |
| from audiotools import AudioSignal |
| from audiotools.ml import BaseModel |
| from encodec import EncodecModel |
|
|
|
|
| class Encodec(BaseModel): |
| def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): |
| super().__init__() |
|
|
| if sample_rate == 24000: |
| self.model = EncodecModel.encodec_model_24khz() |
| else: |
| self.model = EncodecModel.encodec_model_48khz() |
| self.model.set_target_bandwidth(bandwidth) |
| self.sample_rate = 44100 |
|
|
| def forward( |
| self, |
| audio_data: torch.Tensor, |
| sample_rate: int = 44100, |
| n_quantizers: int = None, |
| ): |
| signal = AudioSignal(audio_data, sample_rate) |
| signal.resample(self.model.sample_rate) |
| recons = self.model(signal.audio_data) |
| recons = AudioSignal(recons, self.model.sample_rate) |
| recons.resample(sample_rate) |
| return {"audio": recons.audio_data} |
|
|
|
|
| if __name__ == "__main__": |
| import numpy as np |
| from functools import partial |
|
|
| model = Encodec() |
|
|
| for n, m in model.named_modules(): |
| o = m.extra_repr() |
| p = sum([np.prod(p.size()) for p in m.parameters()]) |
| fn = lambda o, p: o + f" {p/1e6:<.3f}M params." |
| setattr(m, "extra_repr", partial(fn, o=o, p=p)) |
| print(model) |
| print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) |
|
|
| length = 88200 * 2 |
| x = torch.randn(1, 1, length).to(model.device) |
| x.requires_grad_(True) |
| x.retain_grad() |
|
|
| |
| out = model(x)["audio"] |
|
|
| print(x.shape, out.shape) |
|
|