import torch
from torch import Tensor
from transformers import PreTrainedModel

from audio_diffusion_pytorch import DiffusionVocoder, UNetV0, VDiffusion, VSampler

from .config import VocoderConfig


class Vocoder(PreTrainedModel):
    """Diffusion-based vocoder that decodes mel spectrograms into waveforms.

    Wraps ``audio_diffusion_pytorch.DiffusionVocoder`` in a Hugging Face
    ``PreTrainedModel`` so checkpoints can be saved and loaded with the
    standard ``save_pretrained`` / ``from_pretrained`` API.
    """

    config_class = VocoderConfig

    def __init__(self, config: VocoderConfig):
        super().__init__(config)

        # UNetV0 backbone trained with v-diffusion (VDiffusion/VSampler),
        # generating waveforms conditioned on 80-channel mel spectrograms
        # computed at 48 kHz with a 1024-point FFT.
        self.model = DiffusionVocoder(
            net_t=UNetV0,
            mel_channels=80,
            mel_n_fft=1024,
            mel_sample_rate=48000,
            mel_normalize_log=True,
            channels=[8, 32, 64, 256, 256, 512, 512, 1024, 1024],  # U-Net: channels at each layer
            factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],  # U-Net: down/upsampling factors at each layer
            items=[1, 2, 2, 2, 2, 2, 2, 4, 4],  # U-Net: number of repeating items at each layer
            diffusion_t=VDiffusion,
            sampler_t=VSampler,
        )

    def to_spectrogram(self, *args, **kwargs) -> Tensor:
        """Compute the mel spectrogram of a waveform (delegates to the inner model)."""
        return self.model.to_spectrogram(*args, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs) -> Tensor:
        """Decode a mel spectrogram into a waveform with the diffusion sampler."""
        return self.model.sample(*args, **kwargs)
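

# A minimal usage sketch, not part of the module's API: it assumes
# `VocoderConfig` can be constructed without arguments; the input shapes and
# the `num_steps` range follow the audio-diffusion-pytorch README.
if __name__ == "__main__":
    vocoder = Vocoder(VocoderConfig())

    # Training: the inner DiffusionVocoder returns the v-diffusion loss when
    # called directly on waveforms of shape [batch, in_channels, length].
    audio = torch.randn(1, 1, 2**18)
    loss = vocoder.model(audio)
    loss.backward()

    # Inference: mel-encode a waveform, then decode it back to audio with the
    # diffusion sampler. Suggested num_steps is 10-100 (speed/quality trade-off).
    spectrogram = vocoder.to_spectrogram(audio)
    waveform = vocoder.sample(spectrogram, num_steps=10)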