# hifi-wavegan-48khz / inference.py
# Frazun09's picture
# Add inference script
# 6c5af53 verified
"""
HiFi-WaveGAN Inference — Generate 48kHz singing voice from mel-spectrogram.
Usage:
python inference.py --input singing.wav --output generated.wav --checkpoint generator.pt
"""
import os, sys, time, argparse
import torch, torchaudio
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.dataset import MelSpectrogramTransform, estimate_f0
from hifi_wavegan.config import HiFiWaveGANConfig
@torch.no_grad()
def synthesize(gen, path_in, path_out, sr=48000, n_mels=120, hop=240, win=960, n_fft=2048, device='cpu'):
    """Copy-synthesize a waveform: analyze ``path_in`` into mel + pitch features,
    run the generator, and write the result to ``path_out``.

    Args:
        gen: trained generator exposing ``inference(mel, pitch, f0, uv)``.
        path_in: input audio file (any sample rate/channels torchaudio can read).
        path_out: output WAV path; parent directory is created if missing.
        sr: target sample rate in Hz (default 48 kHz).
        n_mels: number of mel bands for the analysis transform.
        hop: STFT hop size in samples (also the f0 frame hop).
        win: STFT window length in samples.
        n_fft: FFT size.
        device: torch device string the generator lives on.
    """
    wav, orig_sr = torchaudio.load(path_in)
    # Collapse multi-channel audio to mono before analysis.
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)
    if orig_sr != sr:
        wav = torchaudio.transforms.Resample(orig_sr, sr)(wav)
    wav = wav.squeeze(0)
    # Peak-normalize to 0.95 for headroom; epsilon guards an all-silent input.
    wav = wav / (wav.abs().max() + 1e-8) * 0.95
    mel = MelSpectrogramTransform(sr, n_fft, win, hop, n_mels)(wav).unsqueeze(0).to(device)
    # f0/uv are estimated on CPU, then batched and moved to the model device.
    f0, uv = estimate_f0(wav.cpu(), sr, hop)
    f0 = f0.unsqueeze(0).to(device)
    uv = uv.unsqueeze(0).to(device)
    # Pitch conditioning: log-Hz standardized with (4.5, 1.5) — presumably the
    # training-set log-f0 mean/std (TODO confirm) — zeroed in unvoiced frames.
    # clamp(min=1.0) keeps log() finite where f0 == 0.
    pitch = ((torch.log(f0.clamp(min=1.0)) - 4.5) / 1.5 * uv).unsqueeze(1)
    t0 = time.time()
    out = gen.inference(mel, pitch, f0, uv).squeeze().cpu()
    # Real-time factor: wall-clock synthesis time / audio duration.
    print(f"RTF: {(time.time() - t0) / (len(wav) / sr):.4f}")
    # path_out may be a bare filename; fall back to '.' so makedirs succeeds.
    os.makedirs(os.path.dirname(path_out) or '.', exist_ok=True)
    torchaudio.save(path_out, out.unsqueeze(0), sr)
    print(f"Saved: {path_out}")
def main():
    """CLI entry point: parse arguments, build the generator from the default
    config, load checkpoint weights, and synthesize the input file.

    Flags: --input / --output / --checkpoint (required), --device
    ('auto' picks cuda when available, else cpu).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--input", required=True)
    p.add_argument("--output", required=True)
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--device", default="auto")
    a = p.parse_args()
    if a.device == 'auto':
        dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        dev = a.device
    cfg = HiFiWaveGANConfig()
    gen = ExWaveNetGenerator(cfg.generator.n_mels, cfg.generator.residual_ch, cfg.generator.skip_ch,
                             cfg.generator.n_stacks, cfg.generator.n_layers_per_stack,
                             cfg.generator.kernel_sizes, cfg.generator.hop_size,
                             cfg.generator.sample_rate, cfg.generator.use_pulse).to(dev)
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    st = torch.load(a.checkpoint, map_location=dev)
    # Accept checkpoints saved as {'gen': sd}, {'generator': sd}, or a raw state_dict.
    gen.load_state_dict(st.get('gen', st.get('generator', st)))
    gen.eval()
    synthesize(gen, a.input, a.output, device=dev)


if __name__ == "__main__":
    main()