File size: 6,491 Bytes
3858c75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
HiFi-WaveGAN Training Script (Modular version)

Usage:
  python train.py --data_dir /path/to/audio --batch_size 8 --total_steps 200000

See train_hifi_wavegan.py for the self-contained single-file version.
"""

import os, sys, time, json, argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.cuda.amp import GradScaler, autocast

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.models.discriminator import MultiPeriodDiscriminator, MultiResolutionSpectrogramDiscriminator
from hifi_wavegan.losses import HiFiWaveGANLoss
from hifi_wavegan.dataset import VocoderDataset
from hifi_wavegan.config import HiFiWaveGANConfig

def count_params(m):
    """Return the number of trainable scalar parameters in module ``m``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def main():
    """Train HiFi-WaveGAN and push the trained generator to the Hugging Face Hub.

    Command-line flags (``--data_dir``, ``--batch_size``, ``--total_steps``,
    ``--hub_model_id``) override the matching fields of
    ``HiFiWaveGANConfig().training``. The function then:

    1. downloads the GTSinger dataset if ``data_dir`` is missing or holds
       fewer than 100 ``.wav`` files,
    2. builds the generator, both discriminators, optimizers, per-epoch
       exponential LR schedulers and AMP grad scalers,
    3. alternates one discriminator update and one generator update per step
       until ``total_steps`` is reached, logging and checkpointing at the
       configured intervals,
    4. saves the generator weights to ``/app/generator.pt`` and uploads them
       to ``hub_model_id``.

    Raises:
        RuntimeError: if the training DataLoader yields no batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="/app/data/gtsinger")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--hub_model_id", default="Frazun09/hifi-wavegan-48khz")
    args = parser.parse_args()

    cfg = HiFiWaveGANConfig()
    cfg.training.data_dir = args.data_dir
    cfg.training.batch_size = args.batch_size
    cfg.training.total_steps = args.total_steps
    cfg.training.hub_model_id = args.hub_model_id

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_amp = device.type == 'cuda'  # mixed precision only on GPU
    print(f"Device: {device}")

    # Download GTSinger if needed
    from huggingface_hub import snapshot_download
    if not os.path.exists(cfg.training.data_dir) or len(list(Path(cfg.training.data_dir).rglob("*.wav"))) < 100:
        snapshot_download("AaronZ345/GTSinger", repo_type="dataset", local_dir=cfg.training.data_dir,
                          ignore_patterns=["*.md","*.json","*.py","*.txt","*Paired_Speech*"])

    gen = ExWaveNetGenerator(cfg.generator.n_mels, cfg.generator.residual_ch, cfg.generator.skip_ch,
                              cfg.generator.n_stacks, cfg.generator.n_layers_per_stack,
                              cfg.generator.kernel_sizes, cfg.generator.hop_size,
                              cfg.generator.sample_rate, cfg.generator.use_pulse).to(device)
    mpd = MultiPeriodDiscriminator(cfg.discriminator.mpd_periods).to(device)
    mrsd = MultiResolutionSpectrogramDiscriminator(cfg.discriminator.mrsd_stft_params).to(device)
    print(f"Generator: {count_params(gen)/1e6:.2f}M | MPD: {count_params(mpd)/1e6:.2f}M | MRSD: {count_params(mrsd)/1e6:.2f}M")

    loss_fn = HiFiWaveGANLoss(cfg.loss.lambda_adv, cfg.loss.lambda_aux, cfg.loss.lambda_fm)
    disc_params = list(mpd.parameters()) + list(mrsd.parameters())
    og = AdamW(gen.parameters(), lr=cfg.training.lr, betas=cfg.training.betas, weight_decay=cfg.training.weight_decay)
    od = AdamW(disc_params, lr=cfg.training.lr, betas=cfg.training.betas, weight_decay=cfg.training.weight_decay)
    sg, sd = ExponentialLR(og, cfg.training.lr_decay), ExponentialLR(od, cfg.training.lr_decay)
    scg, scd = GradScaler(enabled=use_amp), GradScaler(enabled=use_amp)

    train_ds = VocoderDataset(cfg.training.data_dir, cfg.audio.segment_size, cfg.audio.sample_rate,
                               cfg.audio.n_mels, cfg.audio.hop_length, cfg.audio.win_length, cfg.audio.n_fft, 'train')
    train_dl = torch.utils.data.DataLoader(train_ds, cfg.training.batch_size, shuffle=True,
                                            num_workers=cfg.training.num_workers, pin_memory=True, drop_last=True)
    # With drop_last=True a dataset smaller than one batch yields nothing;
    # fail with a clear message instead of a bare StopIteration in the loop.
    if len(train_dl) == 0:
        raise RuntimeError(
            f"No training batches: dataset at {cfg.training.data_dir!r} has "
            f"{len(train_ds)} samples, fewer than batch_size={cfg.training.batch_size}"
        )

    gen.train(); mpd.train(); mrsd.train()
    step, epoch, it = 0, 0, iter(train_dl)
    t0 = time.time()

    while step < cfg.training.total_steps:
        try:
            batch = next(it)
        except StopIteration:
            # Epoch finished: restart the loader and decay both learning rates once.
            epoch += 1
            it = iter(train_dl)
            batch = next(it)
            sg.step(); sd.step()

        wav, mel, pitch = batch['wav'].to(device), batch['mel'].to(device), batch['pitch'].to(device)
        f0, uv = batch['f0'].to(device), batch['uv'].to(device)
        B, _, Tf = mel.shape

        # ---- Discriminator update ----
        od.zero_grad()
        # torch.cuda.amp.autocast does not accept `device_type`; torch.autocast does.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            z = torch.randn(B, 1, Tf, device=device)
            # The generator is frozen for this pass: no graph is built through it.
            with torch.no_grad():
                fake = gen(z, mel, pitch, f0, uv)
            # Crop real and generated audio to a common length before comparing.
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            ro, _ = mpd(rw); fo, _ = mpd(fw)
            rso, _ = mrsd(rw); fso, _ = mrsd(fw)
            dl, dd = loss_fn.discriminator_loss(ro, fo, rso, fso)
        scd.scale(dl).backward()
        scd.unscale_(od)  # unscale first so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(disc_params, cfg.training.grad_clip)
        scd.step(od); scd.update()

        # ---- Generator update ----
        og.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            z = torch.randn(B, 1, Tf, device=device)
            fake = gen(z, mel, pitch, f0, uv)
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            # Second element of each discriminator output is passed to the
            # generator loss (presumably feature maps for the 'fm' term).
            ro, rf = mpd(rw); fo, ff = mpd(fw)
            rso, rsf = mrsd(rw); fso, fsf = mrsd(fw)
            gl, gd = loss_fn.generator_loss(fo, ff, fso, fsf, rf, rsf, fw, rw)
        # This backward also deposits gradients on the discriminators; they are
        # discarded by od.zero_grad() at the top of the next iteration.
        scg.scale(gl).backward()
        scg.unscale_(og)
        torch.nn.utils.clip_grad_norm_(gen.parameters(), cfg.training.grad_clip)
        scg.step(og); scg.update()
        step += 1

        if step % cfg.training.log_interval == 0:
            elapsed = max(time.time() - t0, 1e-9)  # avoid division by zero
            eta = (cfg.training.total_steps - step) / max(step / elapsed, 1e-6) / 3600
            print(f"Step {step}/{cfg.training.total_steps} | G={gd['total']:.2f} adv={gd['adv']:.3f} aux={gd['aux']:.3f} fm={gd['fm']:.3f} | D={dd['total']:.3f} | ETA={eta:.1f}h")

        if step % cfg.training.save_interval == 0:
            os.makedirs(cfg.training.checkpoint_dir, exist_ok=True)
            torch.save({'gen': gen.state_dict(), 'mpd': mpd.state_dict(), 'mrsd': mrsd.state_dict(),
                        'og': og.state_dict(), 'od': od.state_dict(), 'step': step},
                       f"{cfg.training.checkpoint_dir}/step_{step}.pt")

    # Push to Hub
    from huggingface_hub import HfApi
    torch.save(gen.state_dict(), "/app/generator.pt")
    # upload_file takes keyword-only arguments; positional use raises TypeError
    # (and the original order would have swapped repo_id and path_in_repo).
    HfApi().upload_file(
        path_or_fileobj="/app/generator.pt",
        path_in_repo="generator.pt",
        repo_id=cfg.training.hub_model_id,
    )
    print(f"Pushed to https://huggingface.co/{cfg.training.hub_model_id}")

# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()