"""
HiFi-WaveGAN Training Script (Modular version)

Usage:
    python train.py --data_dir /path/to/audio --batch_size 8 --total_steps 200000

See train_hifi_wavegan.py for the self-contained single-file version.
"""
import argparse
import itertools
import json
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.cuda.amp import GradScaler, autocast

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.models.discriminator import MultiPeriodDiscriminator, MultiResolutionSpectrogramDiscriminator
from hifi_wavegan.losses import HiFiWaveGANLoss
from hifi_wavegan.dataset import VocoderDataset
from hifi_wavegan.config import HiFiWaveGANConfig


def count_params(m):
    """Return the number of trainable (``requires_grad``) parameters in module ``m``."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def _amp_context(device: torch.device):
    """Return an autocast context that enables mixed precision only on CUDA.

    ``torch.cuda.amp.autocast`` does not accept a ``device_type`` argument;
    the device-generic ``torch.autocast`` does, and is a no-op when disabled.
    """
    return torch.autocast(device_type=device.type, enabled=device.type == 'cuda')


def _has_at_least(root: str, pattern: str, n: int) -> bool:
    """Return True if ``root`` contains at least ``n`` files matching ``pattern``.

    Stops scanning after ``n`` hits instead of walking the whole tree.
    """
    return sum(1 for _ in itertools.islice(Path(root).rglob(pattern), n)) >= n


def _ensure_dataset(data_dir: str) -> None:
    """Download GTSinger into ``data_dir`` unless it already holds >= 100 ``.wav`` files."""
    if os.path.exists(data_dir) and _has_at_least(data_dir, "*.wav", 100):
        return
    from huggingface_hub import snapshot_download
    snapshot_download("AaronZ345/GTSinger", repo_type="dataset", local_dir=data_dir,
                      ignore_patterns=["*.md", "*.json", "*.py", "*.txt", "*Paired_Speech*"])


def _crop_to_common_length(real, fake):
    """Trim ``real`` and ``fake`` along the last axis to the shorter of the two lengths."""
    n = min(real.shape[-1], fake.shape[-1])
    return real[..., :n], fake[..., :n]


def _push_generator(gen, repo_id: str, local_path: str) -> None:
    """Save the generator ``state_dict`` to ``local_path`` and upload it to ``repo_id``.

    An empty ``repo_id`` saves locally and skips the upload.
    """
    os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
    torch.save(gen.state_dict(), local_path)
    if not repo_id:
        print(f"Saved generator to {local_path} (no hub_model_id; skipping push)")
        return
    from huggingface_hub import HfApi
    api = HfApi()
    # upload_file fails on a missing repo; create it (no-op if it already exists).
    api.create_repo(repo_id=repo_id, exist_ok=True)
    # upload_file's arguments are keyword-only.
    api.upload_file(path_or_fileobj=local_path, path_in_repo="generator.pt", repo_id=repo_id)
    print(f"Pushed to https://huggingface.co/{repo_id}")


def main():
    """Train the ExWaveNet generator against MPD + MRSD discriminators.

    Each step runs one discriminator update followed by one generator update.
    Learning rates decay once per epoch. Full training checkpoints are written
    every ``save_interval`` steps. The final generator weights are saved and
    (optionally) pushed to the Hugging Face Hub.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="/app/data/gtsinger")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--hub_model_id", default="Frazun09/hifi-wavegan-48khz",
                        help="Hub repo for the final generator; pass '' to skip pushing.")
    parser.add_argument("--output_path", default="/app/generator.pt",
                        help="Local path for the final generator state_dict.")
    args = parser.parse_args()

    cfg = HiFiWaveGANConfig()
    cfg.training.data_dir = args.data_dir
    cfg.training.batch_size = args.batch_size
    cfg.training.total_steps = args.total_steps
    cfg.training.hub_model_id = args.hub_model_id

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    print(f"Device: {device}")

    # Download GTSinger if needed
    _ensure_dataset(cfg.training.data_dir)

    gen = ExWaveNetGenerator(cfg.generator.n_mels, cfg.generator.residual_ch, cfg.generator.skip_ch,
                             cfg.generator.n_stacks, cfg.generator.n_layers_per_stack,
                             cfg.generator.kernel_sizes, cfg.generator.hop_size,
                             cfg.generator.sample_rate, cfg.generator.use_pulse).to(device)
    mpd = MultiPeriodDiscriminator(cfg.discriminator.mpd_periods).to(device)
    mrsd = MultiResolutionSpectrogramDiscriminator(cfg.discriminator.mrsd_stft_params).to(device)
    print(f"Generator: {count_params(gen)/1e6:.2f}M | MPD: {count_params(mpd)/1e6:.2f}M | MRSD: {count_params(mrsd)/1e6:.2f}M")

    loss_fn = HiFiWaveGANLoss(cfg.loss.lambda_adv, cfg.loss.lambda_aux, cfg.loss.lambda_fm)
    # Built once; shared by the D optimizer and D gradient clipping.
    d_params = list(mpd.parameters()) + list(mrsd.parameters())
    og = AdamW(gen.parameters(), lr=cfg.training.lr, betas=cfg.training.betas,
               weight_decay=cfg.training.weight_decay)
    od = AdamW(d_params, lr=cfg.training.lr, betas=cfg.training.betas,
               weight_decay=cfg.training.weight_decay)
    sg, sd = ExponentialLR(og, cfg.training.lr_decay), ExponentialLR(od, cfg.training.lr_decay)
    scg, scd = GradScaler(enabled=on_cuda), GradScaler(enabled=on_cuda)

    train_ds = VocoderDataset(cfg.training.data_dir, cfg.audio.segment_size, cfg.audio.sample_rate,
                              cfg.audio.n_mels, cfg.audio.hop_length, cfg.audio.win_length,
                              cfg.audio.n_fft, 'train')
    train_dl = torch.utils.data.DataLoader(train_ds, cfg.training.batch_size, shuffle=True,
                                           num_workers=cfg.training.num_workers,
                                           pin_memory=on_cuda, drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields nothing. The
    # epoch-restart below would then leak a bare StopIteration, so fail clearly here.
    if len(train_dl) == 0:
        raise ValueError(
            f"Training set has {len(train_ds)} samples, fewer than batch_size="
            f"{cfg.training.batch_size}; no full batch can be formed.")

    gen.train(); mpd.train(); mrsd.train()
    step, epoch, it = 0, 0, iter(train_dl)
    t0 = time.time()
    while step < cfg.training.total_steps:
        try:
            batch = next(it)
        except StopIteration:
            # End of epoch: restart the loader and apply the per-epoch LR decay.
            epoch += 1
            it = iter(train_dl)
            batch = next(it)
            sg.step(); sd.step()

        wav, mel, pitch = batch['wav'].to(device), batch['mel'].to(device), batch['pitch'].to(device)
        f0, uv = batch['f0'].to(device), batch['uv'].to(device)
        B, _, Tf = mel.shape

        # ---- Discriminator update (generator output produced without a graph) ----
        od.zero_grad()
        with _amp_context(device):
            z = torch.randn(B, 1, Tf, device=device)
            with torch.no_grad():
                fake = gen(z, mel, pitch, f0, uv)
            rw, fw = _crop_to_common_length(wav, fake)
            ro, _ = mpd(rw); fo, _ = mpd(fw)
            rso, _ = mrsd(rw); fso, _ = mrsd(fw)
            dl, dd = loss_fn.discriminator_loss(ro, fo, rso, fso)
        scd.scale(dl).backward()
        scd.unscale_(od)  # clip on true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(d_params, cfg.training.grad_clip)
        scd.step(od); scd.update()

        # ---- Generator update ----
        og.zero_grad()
        with _amp_context(device):
            z = torch.randn(B, 1, Tf, device=device)
            fake = gen(z, mel, pitch, f0, uv)
            rw, fw = _crop_to_common_length(wav, fake)
            # Real-audio features serve only as constant feature-matching targets.
            # They do not depend on G, so skip building a graph for them.
            with torch.no_grad():
                _, rf = mpd(rw)
            fo, ff = mpd(fw)
            with torch.no_grad():
                _, rsf = mrsd(rw)
            fso, fsf = mrsd(fw)
            gl, gd = loss_fn.generator_loss(fo, ff, fso, fsf, rf, rsf, fw, rw)
        # This backward also fills D's .grad via the fake passes; those are
        # discarded by od.zero_grad() before the next discriminator update.
        scg.scale(gl).backward()
        scg.unscale_(og)
        torch.nn.utils.clip_grad_norm_(gen.parameters(), cfg.training.grad_clip)
        scg.step(og); scg.update()
        step += 1

        if step % cfg.training.log_interval == 0:
            steps_per_sec = max(step / (time.time() - t0), 1e-6)
            eta = (cfg.training.total_steps - step) / steps_per_sec / 3600
            print(f"Step {step}/{cfg.training.total_steps} | G={gd['total']:.2f} adv={gd['adv']:.3f} aux={gd['aux']:.3f} fm={gd['fm']:.3f} | D={dd['total']:.3f} | ETA={eta:.1f}h")

        if step % cfg.training.save_interval == 0:
            os.makedirs(cfg.training.checkpoint_dir, exist_ok=True)
            torch.save({'gen': gen.state_dict(), 'mpd': mpd.state_dict(), 'mrsd': mrsd.state_dict(),
                        'og': og.state_dict(), 'od': od.state_dict(), 'step': step,
                        # Extra state so training can be resumed faithfully.
                        'sg': sg.state_dict(), 'sd': sd.state_dict(),
                        'scg': scg.state_dict(), 'scd': scd.state_dict(), 'epoch': epoch},
                       os.path.join(cfg.training.checkpoint_dir, f"step_{step}.pt"))

    # Push to Hub
    _push_generator(gen, cfg.training.hub_model_id, args.output_path)


if __name__ == "__main__":
    main()