File size: 2,387 Bytes
6c5af53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
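HiFi-WaveGAN inference — re-synthesize a 48 kHz singing voice from the mel-spectrogram
and F0 contour extracted from an input recording.

Usage:
  python inference.py --input singing.wav --output generated.wav --checkpoint generator.pt
                      [--device DEVICE]   (default: auto = CUDA if available, else CPU)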
"""
import os, sys, time, argparse
import torch, torchaudio

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.dataset import MelSpectrogramTransform, estimate_f0
from hifi_wavegan.config import HiFiWaveGANConfig

@torch.no_grad()
def synthesize(gen, path_in, path_out, sr=48000, n_mels=120, hop=240, win=960, n_fft=2048, device='cpu'):
    """Re-synthesize an audio file through a HiFi-WaveGAN generator and save the result.

    The input is downmixed to mono, resampled to ``sr`` and peak-normalized to 0.95.
    A mel-spectrogram and an F0 / voiced-unvoiced track are extracted from it and fed
    to ``gen.inference``; the generated waveform is written to ``path_out`` at ``sr``.
    The real-time factor (generation time divided by input duration) is printed.

    Args:
        gen: Generator exposing ``inference(mel, pitch, f0, uv)``.
        path_in: Path of the audio file to analyse.
        path_out: Destination path; missing parent directories are created.
        sr: Target sample rate for analysis and output.
        n_mels, hop, win, n_fft: Mel-spectrogram parameters, passed positionally to
            ``MelSpectrogramTransform`` as ``(sr, n_fft, win, hop, n_mels)``.
        device: Device the features are moved to before calling ``gen``.

    Raises:
        ValueError: If the input file contains no samples.
    """
    wav, orig_sr = torchaudio.load(path_in)
    if wav.numel() == 0:
        # Without this guard, peak normalization and the RTF computation below
        # fail with an opaque reduction / division-by-zero error.
        raise ValueError(f"input audio is empty: {path_in}")
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)  # downmix to mono
    if orig_sr != sr:
        wav = torchaudio.transforms.Resample(orig_sr, sr)(wav)
    wav = wav.squeeze(0)
    # Peak-normalize with headroom; the epsilon keeps all-silent input finite.
    wav = wav / (wav.abs().max() + 1e-8) * 0.95

    mel = MelSpectrogramTransform(sr, n_fft, win, hop, n_mels)(wav).unsqueeze(0).to(device)
    f0, uv = estimate_f0(wav, sr, hop)  # wav is still a CPU tensor at this point
    f0 = f0.unsqueeze(0).to(device)
    uv = uv.unsqueeze(0).to(device)
    # Normalized log-F0, zeroed on unvoiced frames; the clamp avoids log(0).
    # NOTE(review): 4.5 / 1.5 are presumably training-time log-F0 mean/std — confirm.
    pitch = ((torch.log(f0.clamp(min=1.0)) - 4.5) / 1.5 * uv).unsqueeze(1)

    t0 = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
    # .cpu() waits for any pending device work, so the measured time is complete.
    out = gen.inference(mel, pitch, f0, uv).squeeze().cpu()
    elapsed = time.perf_counter() - t0
    print(f"RTF: {elapsed / (len(wav) / sr):.4f}")

    os.makedirs(os.path.dirname(path_out) or '.', exist_ok=True)
    torchaudio.save(path_out, out.unsqueeze(0), sr)
    print(f"Saved: {path_out}")

def main():
    """Command-line entry point: load a generator checkpoint and re-synthesize one file."""
    p = argparse.ArgumentParser(description="HiFi-WaveGAN inference")
    p.add_argument("--input", required=True, help="input audio file")
    p.add_argument("--output", required=True, help="path of the generated audio file")
    p.add_argument("--checkpoint", required=True, help="generator checkpoint (.pt)")
    p.add_argument("--device", default="auto",
                   help="'auto' (CUDA if available, else CPU) or an explicit torch device")
    a = p.parse_args()
    dev = a.device if a.device != 'auto' else ('cuda' if torch.cuda.is_available() else 'cpu')

    cfg = HiFiWaveGANConfig()
    g = cfg.generator
    gen = ExWaveNetGenerator(g.n_mels, g.residual_ch, g.skip_ch,
                             g.n_stacks, g.n_layers_per_stack,
                             g.kernel_sizes, g.hop_size,
                             g.sample_rate, g.use_pulse).to(dev)

    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    st = torch.load(a.checkpoint, map_location=dev)
    # Accept a bare state dict, or a training checkpoint storing it under 'gen' / 'generator'.
    gen.load_state_dict(st.get('gen', st.get('generator', st)))
    gen.eval()

    # Extract features with the same sample rate, mel count and hop size the generator
    # was built with; relying on synthesize()'s defaults silently mismatches if the
    # config changes.
    synthesize(gen, a.input, a.output,
               sr=g.sample_rate, n_mels=g.n_mels, hop=g.hop_size, device=dev)


if __name__ == "__main__":
    main()