| """ |
| HiFi-WaveGAN Training Script (Modular version) |
| |
| Usage: |
| python train.py --data_dir /path/to/audio --batch_size 8 --total_steps 200000 |
| |
| See train_hifi_wavegan.py for the self-contained single-file version. |
| """ |
|
|
| import os, sys, time, json, argparse |
| from pathlib import Path |
| import torch |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import ExponentialLR |
| from torch.cuda.amp import GradScaler, autocast |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from hifi_wavegan.models.generator import ExWaveNetGenerator |
| from hifi_wavegan.models.discriminator import MultiPeriodDiscriminator, MultiResolutionSpectrogramDiscriminator |
| from hifi_wavegan.losses import HiFiWaveGANLoss |
| from hifi_wavegan.dataset import VocoderDataset |
| from hifi_wavegan.config import HiFiWaveGANConfig |
|
|
def count_params(m):
    """Return how many scalar values live in the trainable parameters of module *m*.

    Parameters whose ``requires_grad`` is False (frozen) are not counted.
    """
    total = 0
    for param in m.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
def _parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_dir``, ``batch_size``, ``total_steps``
        and ``hub_model_id``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="/app/data/gtsinger")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument(
        "--hub_model_id",
        default="Frazun09/hifi-wavegan-48khz",
        help="Hub repo that receives generator.pt; pass an empty string to skip the upload.",
    )
    return parser.parse_args(argv)


def _build_config(args):
    """Return a default HiFiWaveGANConfig with its training fields overridden from *args*."""
    cfg = HiFiWaveGANConfig()
    cfg.training.data_dir = args.data_dir
    cfg.training.batch_size = args.batch_size
    cfg.training.total_steps = args.total_steps
    cfg.training.hub_model_id = args.hub_model_id
    return cfg


def _ensure_dataset(data_dir, min_wavs=100):
    """Download GTSinger into *data_dir* unless it already holds >= *min_wavs* .wav files."""
    from huggingface_hub import snapshot_download

    if os.path.exists(data_dir) and len(list(Path(data_dir).rglob("*.wav"))) >= min_wavs:
        return
    snapshot_download(
        "AaronZ345/GTSinger",
        repo_type="dataset",
        local_dir=data_dir,
        ignore_patterns=["*.md", "*.json", "*.py", "*.txt", "*Paired_Speech*"],
    )


def _push_to_hub(gen, cfg):
    """Save the generator weights under ``checkpoint_dir`` and upload them to the Hub.

    The local save always happens first, so the weights survive a failed upload.
    The upload is skipped when ``cfg.training.hub_model_id`` is empty.
    """
    repo_id = cfg.training.hub_model_id
    os.makedirs(cfg.training.checkpoint_dir, exist_ok=True)
    local_path = os.path.join(cfg.training.checkpoint_dir, "generator.pt")
    torch.save(gen.state_dict(), local_path)
    if not repo_id:
        print(f"Saved generator to {local_path} (no hub_model_id, upload skipped)")
        return

    from huggingface_hub import HfApi

    api = HfApi()
    # The target repo may not exist yet on a first run.
    api.create_repo(repo_id=repo_id, exist_ok=True)
    # upload_file only accepts keyword arguments.
    api.upload_file(path_or_fileobj=local_path, path_in_repo="generator.pt", repo_id=repo_id)
    print(f"Pushed to https://huggingface.co/{repo_id}")


def main(argv=None):
    """Train HiFi-WaveGAN, then save and upload the final generator weights.

    Workflow:
    1. Parse CLI options into a HiFiWaveGANConfig.
    2. Download GTSinger if the data directory is missing or has fewer than
       100 .wav files.
    3. Build the generator and the two discriminators (MPD + MRSD).
    4. Run ``total_steps`` iterations, each doing one discriminator update
       followed by one generator update on the same batch.
    5. Log every ``log_interval`` steps and checkpoint every ``save_interval``
       steps; both learning rates decay once per epoch.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]``).

    Raises:
        ValueError: if the DataLoader cannot yield a single full batch.
    """
    cfg = _build_config(_parse_args(argv))
    tc = cfg.training

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    print(f"Device: {device}")

    _ensure_dataset(tc.data_dir)

    # ---- Models ----
    gen = ExWaveNetGenerator(cfg.generator.n_mels, cfg.generator.residual_ch, cfg.generator.skip_ch,
                             cfg.generator.n_stacks, cfg.generator.n_layers_per_stack,
                             cfg.generator.kernel_sizes, cfg.generator.hop_size,
                             cfg.generator.sample_rate, cfg.generator.use_pulse).to(device)
    mpd = MultiPeriodDiscriminator(cfg.discriminator.mpd_periods).to(device)
    mrsd = MultiResolutionSpectrogramDiscriminator(cfg.discriminator.mrsd_stft_params).to(device)
    print(f"Generator: {count_params(gen)/1e6:.2f}M | MPD: {count_params(mpd)/1e6:.2f}M | MRSD: {count_params(mrsd)/1e6:.2f}M")

    # ---- Loss, optimisers, schedulers, AMP scalers ----
    loss_fn = HiFiWaveGANLoss(cfg.loss.lambda_adv, cfg.loss.lambda_aux, cfg.loss.lambda_fm)
    disc_params = list(mpd.parameters()) + list(mrsd.parameters())
    og = AdamW(gen.parameters(), lr=tc.lr, betas=tc.betas, weight_decay=tc.weight_decay)
    od = AdamW(disc_params, lr=tc.lr, betas=tc.betas, weight_decay=tc.weight_decay)
    sg = ExponentialLR(og, tc.lr_decay)
    sd = ExponentialLR(od, tc.lr_decay)
    # Mixed precision is only enabled on CUDA; disabled scalers are pass-throughs.
    scg = GradScaler(enabled=on_cuda)
    scd = GradScaler(enabled=on_cuda)

    # ---- Data ----
    train_ds = VocoderDataset(tc.data_dir, cfg.audio.segment_size, cfg.audio.sample_rate,
                              cfg.audio.n_mels, cfg.audio.hop_length, cfg.audio.win_length,
                              cfg.audio.n_fft, 'train')
    train_dl = torch.utils.data.DataLoader(train_ds, tc.batch_size, shuffle=True,
                                           num_workers=tc.num_workers,
                                           # Pinned memory only speeds up host->GPU copies.
                                           pin_memory=on_cuda, drop_last=True)
    if len(train_dl) == 0:
        # With drop_last=True an undersized dataset yields no batches at all and the
        # loop below would otherwise die with a bare StopIteration.
        raise ValueError(
            f"Dataset has {len(train_ds)} items, fewer than batch_size={tc.batch_size}; "
            "no full batch can be formed."
        )

    gen.train()
    mpd.train()
    mrsd.train()
    step, epoch, it = 0, 0, iter(train_dl)
    t0 = time.time()

    while step < tc.total_steps:
        try:
            batch = next(it)
        except StopIteration:
            # One full pass over the data: restart the iterator and decay both LRs.
            epoch += 1
            it = iter(train_dl)
            batch = next(it)
            sg.step()
            sd.step()

        wav = batch['wav'].to(device)
        mel = batch['mel'].to(device)
        pitch = batch['pitch'].to(device)
        f0 = batch['f0'].to(device)
        uv = batch['uv'].to(device)
        B, _, Tf = mel.shape

        # ---- Discriminator update: generator output is produced without a graph ----
        od.zero_grad()
        with torch.autocast(device_type=device.type, enabled=on_cuda):
            z = torch.randn(B, 1, Tf, device=device)
            with torch.no_grad():
                fake = gen(z, mel, pitch, f0, uv)
            # Generator output length may differ slightly from the target; align them.
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            ro, _ = mpd(rw)
            fo, _ = mpd(fw)
            rso, _ = mrsd(rw)
            fso, _ = mrsd(fw)
            dl, dd = loss_fn.discriminator_loss(ro, fo, rso, fso)
        # Backward runs outside autocast, as recommended by the PyTorch AMP docs.
        scd.scale(dl).backward()
        scd.unscale_(od)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(disc_params, tc.grad_clip)
        scd.step(od)
        scd.update()

        # ---- Generator update ----
        og.zero_grad()
        with torch.autocast(device_type=device.type, enabled=on_cuda):
            z = torch.randn(B, 1, Tf, device=device)
            fake = gen(z, mel, pitch, f0, uv)
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            # Real-audio features are feature-matching targets only and do not depend
            # on the generator, so no autograd graph is needed for them.
            with torch.no_grad():
                _, rf = mpd(rw)
                _, rsf = mrsd(rw)
            fo, ff = mpd(fw)
            fso, fsf = mrsd(fw)
            gl, gd = loss_fn.generator_loss(fo, ff, fso, fsf, rf, rsf, fw, rw)
        scg.scale(gl).backward()
        scg.unscale_(og)
        torch.nn.utils.clip_grad_norm_(gen.parameters(), tc.grad_clip)
        scg.step(og)
        scg.update()
        step += 1

        if step % tc.log_interval == 0:
            steps_per_sec = step / max(time.time() - t0, 1e-9)
            eta = (tc.total_steps - step) / max(steps_per_sec, 1e-6) / 3600
            print(f"Step {step}/{tc.total_steps} | G={gd['total']:.2f} adv={gd['adv']:.3f} aux={gd['aux']:.3f} fm={gd['fm']:.3f} | D={dd['total']:.3f} | ETA={eta:.1f}h")

        if step % tc.save_interval == 0:
            os.makedirs(tc.checkpoint_dir, exist_ok=True)
            torch.save({'gen': gen.state_dict(), 'mpd': mpd.state_dict(), 'mrsd': mrsd.state_dict(),
                        'og': og.state_dict(), 'od': od.state_dict(), 'step': step,
                        # Extra state so a run can be resumed exactly.
                        'sg': sg.state_dict(), 'sd': sd.state_dict(),
                        'scg': scg.state_dict(), 'scd': scd.state_dict(), 'epoch': epoch},
                       os.path.join(tc.checkpoint_dir, f"step_{step}.pt"))

    _push_to_hub(gen, cfg)
|
|
# Script entry point: run the full training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|