# hifi-wavegan-48khz / train.py
# (Hub page artifact preserved as comments: uploaded by Frazun09,
#  commit 3858c75 "Add modular training script", verified)
"""
HiFi-WaveGAN Training Script (Modular version)
Usage:
python train.py --data_dir /path/to/audio --batch_size 8 --total_steps 200000
See train_hifi_wavegan.py for the self-contained single-file version.
"""
import os, sys, time, json, argparse
from pathlib import Path
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.cuda.amp import GradScaler, autocast
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hifi_wavegan.models.generator import ExWaveNetGenerator
from hifi_wavegan.models.discriminator import MultiPeriodDiscriminator, MultiResolutionSpectrogramDiscriminator
from hifi_wavegan.losses import HiFiWaveGANLoss
from hifi_wavegan.dataset import VocoderDataset
from hifi_wavegan.config import HiFiWaveGANConfig
def count_params(module):
    """Return the number of trainable (requires_grad) parameters in *module*."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def main():
    """Train the HiFi-WaveGAN vocoder end-to-end.

    Parses CLI arguments into a HiFiWaveGANConfig, (optionally) downloads the
    GTSinger dataset, builds the generator and both discriminators, and runs
    the adversarial training loop with mixed precision. Checkpoints are saved
    every ``save_interval`` steps; the final generator weights are pushed to
    the Hugging Face Hub.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="/app/data/gtsinger")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--total_steps", type=int, default=200000)
    parser.add_argument("--hub_model_id", default="Frazun09/hifi-wavegan-48khz")
    args = parser.parse_args()

    cfg = HiFiWaveGANConfig()
    cfg.training.data_dir = args.data_dir
    cfg.training.batch_size = args.batch_size
    cfg.training.total_steps = args.total_steps
    cfg.training.hub_model_id = args.hub_model_id

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    # Download GTSinger if the data dir is missing or looks incomplete
    # (fewer than 100 wav files is treated as a partial/failed download).
    from huggingface_hub import snapshot_download
    if not os.path.exists(cfg.training.data_dir) or len(list(Path(cfg.training.data_dir).rglob("*.wav"))) < 100:
        snapshot_download("AaronZ345/GTSinger", repo_type="dataset", local_dir=cfg.training.data_dir,
                          ignore_patterns=["*.md", "*.json", "*.py", "*.txt", "*Paired_Speech*"])

    # Models: one generator, two discriminators (multi-period + multi-resolution spectrogram).
    gen = ExWaveNetGenerator(cfg.generator.n_mels, cfg.generator.residual_ch, cfg.generator.skip_ch,
                             cfg.generator.n_stacks, cfg.generator.n_layers_per_stack,
                             cfg.generator.kernel_sizes, cfg.generator.hop_size,
                             cfg.generator.sample_rate, cfg.generator.use_pulse).to(device)
    mpd = MultiPeriodDiscriminator(cfg.discriminator.mpd_periods).to(device)
    mrsd = MultiResolutionSpectrogramDiscriminator(cfg.discriminator.mrsd_stft_params).to(device)
    print(f"Generator: {count_params(gen)/1e6:.2f}M | MPD: {count_params(mpd)/1e6:.2f}M | MRSD: {count_params(mrsd)/1e6:.2f}M")

    loss_fn = HiFiWaveGANLoss(cfg.loss.lambda_adv, cfg.loss.lambda_aux, cfg.loss.lambda_fm)

    # Separate optimizers/schedulers/scalers for generator and the two discriminators.
    og = AdamW(gen.parameters(), lr=cfg.training.lr, betas=cfg.training.betas, weight_decay=cfg.training.weight_decay)
    od = AdamW(list(mpd.parameters()) + list(mrsd.parameters()), lr=cfg.training.lr,
               betas=cfg.training.betas, weight_decay=cfg.training.weight_decay)
    sg, sd = ExponentialLR(og, cfg.training.lr_decay), ExponentialLR(od, cfg.training.lr_decay)
    use_amp = device.type == 'cuda'
    scg, scd = GradScaler(enabled=use_amp), GradScaler(enabled=use_amp)

    train_ds = VocoderDataset(cfg.training.data_dir, cfg.audio.segment_size, cfg.audio.sample_rate,
                              cfg.audio.n_mels, cfg.audio.hop_length, cfg.audio.win_length, cfg.audio.n_fft, 'train')
    train_dl = torch.utils.data.DataLoader(train_ds, cfg.training.batch_size, shuffle=True,
                                           num_workers=cfg.training.num_workers, pin_memory=True, drop_last=True)

    gen.train(); mpd.train(); mrsd.train()
    step, epoch, it = 0, 0, iter(train_dl)
    t0 = time.time()

    while step < cfg.training.total_steps:
        try:
            batch = next(it)
        except StopIteration:
            # End of epoch: restart the iterator and decay both LR schedules once per epoch.
            epoch += 1; it = iter(train_dl); batch = next(it); sg.step(); sd.step()

        wav, mel, pitch = batch['wav'].to(device), batch['mel'].to(device), batch['pitch'].to(device)
        f0, uv = batch['f0'].to(device), batch['uv'].to(device)
        B, _, Tf = mel.shape

        # ---- Discriminator update (generator frozen via no_grad) ----
        od.zero_grad()
        # BUG FIX: the `autocast` imported from torch.cuda.amp has no
        # `device_type` parameter; use torch.amp.autocast, which does.
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            z = torch.randn(B, 1, Tf, device=device)
            with torch.no_grad():
                fake = gen(z, mel, pitch, f0, uv)
            # Generator output length may differ by a few samples; crop both to match.
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            ro, _ = mpd(rw); fo, _ = mpd(fw)
            rso, _ = mrsd(rw); fso, _ = mrsd(fw)
            dl, dd = loss_fn.discriminator_loss(ro, fo, rso, fso)
        scd.scale(dl).backward()
        scd.unscale_(od)  # unscale before clipping so the clip threshold is in true units
        torch.nn.utils.clip_grad_norm_(list(mpd.parameters()) + list(mrsd.parameters()), cfg.training.grad_clip)
        scd.step(od); scd.update()

        # ---- Generator update (fresh forward pass with gradients) ----
        og.zero_grad()
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            z = torch.randn(B, 1, Tf, device=device)
            fake = gen(z, mel, pitch, f0, uv)
            ml = min(wav.shape[-1], fake.shape[-1])
            rw, fw = wav[..., :ml], fake[..., :ml]
            ro, rf = mpd(rw); fo, ff = mpd(fw)
            rso, rsf = mrsd(rw); fso, fsf = mrsd(fw)
            gl, gd = loss_fn.generator_loss(fo, ff, fso, fsf, rf, rsf, fw, rw)
        scg.scale(gl).backward()
        scg.unscale_(og)
        torch.nn.utils.clip_grad_norm_(gen.parameters(), cfg.training.grad_clip)
        scg.step(og); scg.update()

        step += 1
        if step % cfg.training.log_interval == 0:
            # ETA in hours, guarded against division by zero on the first log.
            eta = (cfg.training.total_steps - step) / max(step / (time.time() - t0), 1e-6) / 3600
            print(f"Step {step}/{cfg.training.total_steps} | G={gd['total']:.2f} adv={gd['adv']:.3f} aux={gd['aux']:.3f} fm={gd['fm']:.3f} | D={dd['total']:.3f} | ETA={eta:.1f}h")
        if step % cfg.training.save_interval == 0:
            os.makedirs(cfg.training.checkpoint_dir, exist_ok=True)
            torch.save({'gen': gen.state_dict(), 'mpd': mpd.state_dict(), 'mrsd': mrsd.state_dict(),
                        'og': og.state_dict(), 'od': od.state_dict(), 'step': step},
                       f"{cfg.training.checkpoint_dir}/step_{step}.pt")

    # Push final generator weights to the Hub.
    from huggingface_hub import HfApi
    torch.save(gen.state_dict(), "/app/generator.pt")
    # BUG FIX: HfApi.upload_file takes keyword-only arguments (and the original
    # positional order also swapped repo_id and path_in_repo).
    HfApi().upload_file(path_or_fileobj="/app/generator.pt",
                        path_in_repo="generator.pt",
                        repo_id=cfg.training.hub_model_id)
    print(f"Pushed to https://huggingface.co/{cfg.training.hub_model_id}")


if __name__ == "__main__":
    main()