"""
HiFi-WaveGAN Inference — Generate 48kHz singing voice from mel-spectrogram.

Usage:
    python inference.py --input singing.wav --output generated.wav --checkpoint generator.pt
"""
import os, sys, time, argparse
import torch, torchaudio


sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.dataset import MelSpectrogramTransform, estimate_f0
from hifi_wavegan.config import HiFiWaveGANConfig
|
|
@torch.no_grad()
def synthesize(gen, path_in, path_out, sr=48000, n_mels=120, hop=240, win=960, n_fft=2048, device='cpu'):
    """Vocode one audio file: extract mel + pitch features, run the generator, save the wav.

    Args:
        gen: trained generator exposing ``inference(mel, pitch, f0, uv)``.
        path_in: input audio file (any sample rate; resampled to ``sr``).
        path_out: destination wav path; parent directory is created if missing.
        sr, n_mels, hop, win, n_fft: feature-extraction parameters.
        device: torch device string for the feature tensors.
    """
    audio, source_rate = torchaudio.load(path_in)
    # Downmix multi-channel input to mono.
    if audio.shape[0] > 1:
        audio = audio.mean(0, keepdim=True)
    # Resample to the target rate when the file rate differs.
    if source_rate != sr:
        audio = torchaudio.transforms.Resample(source_rate, sr)(audio)
    audio = audio.squeeze(0)
    # Peak-normalize to 0.95; epsilon guards a fully silent input.
    audio = audio / (audio.abs().max() + 1e-8) * 0.95
    mel_transform = MelSpectrogramTransform(sr, n_fft, win, hop, n_mels)
    mel = mel_transform(audio).unsqueeze(0).to(device)
    f0, uv = estimate_f0(audio.cpu(), sr, hop)
    f0 = f0.unsqueeze(0).to(device)
    uv = uv.unsqueeze(0).to(device)
    # Normalize log-F0 with constants 4.5 / 1.5 (presumably dataset log-F0
    # mean/std — confirm against training code) and zero unvoiced frames.
    pitch = ((torch.log(f0.clamp(min=1.0)) - 4.5) / 1.5 * uv).unsqueeze(1)
    start = time.time()
    generated = gen.inference(mel, pitch, f0, uv).squeeze().cpu()
    print(f"RTF: {(time.time() - start) / (len(audio) / sr):.4f}")
    os.makedirs(os.path.dirname(path_out) or '.', exist_ok=True)
    torchaudio.save(path_out, generated.unsqueeze(0), sr)
    print(f"Saved: {path_out}")
|
|
def main():
    """CLI entry point: parse arguments, build the generator, load weights, synthesize.

    Improvements over the previous version: the "auto" device resolution is an
    explicit branch instead of a nested conditional expression, one statement
    per line, and the unpickling risk of ``torch.load`` is flagged.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--device", default="auto")
    args = parser.parse_args()

    # "auto" picks cuda when available, else cpu; any explicit device is used as-is.
    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    cfg = HiFiWaveGANConfig()
    g = cfg.generator
    gen = ExWaveNetGenerator(
        g.n_mels, g.residual_ch, g.skip_ch,
        g.n_stacks, g.n_layers_per_stack,
        g.kernel_sizes, g.hop_size,
        g.sample_rate, g.use_pulse,
    ).to(device)

    # SECURITY: torch.load unpickles arbitrary objects — only load trusted checkpoints.
    state = torch.load(args.checkpoint, map_location=device)
    # Checkpoint may nest weights under 'gen' or 'generator', or be a raw state_dict.
    gen.load_state_dict(state.get('gen', state.get('generator', state)))
    gen.eval()
    synthesize(gen, args.input, args.output, device=device)
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|