| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader |
| from models_conv import ConvGenerator, ConvDiscriminator |
| import os |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| |
# --- Hyperparameters (WGAN with weight clipping) ---
latent_dim = 100     # size of the noise vector fed to the generator
batch_size = 64      # samples per dataloader batch
n_epochs = 200       # full passes over the training set
lr = 0.00005         # RMSprop learning rate
n_critic = 5         # critic (discriminator) updates per generator update
clip_value = 0.01    # critic weights are clamped to [-clip_value, clip_value]
|
|
| |
# Output directories for samples and model snapshots.
# NOTE(review): 'images' is created but this script never writes files into
# it (samples go to TensorBoard only) — confirm whether image saving was intended.
os.makedirs('images', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)


# TensorBoard event log; inspect with `tensorboard --logdir runs`.
writer = SummaryWriter('runs/wgan_training')
|
|
| |
# Convert MNIST images to tensors and normalize the single channel from
# [0, 1] to [-1, 1] (mean=0.5, std=0.5) — presumably to match a tanh-output
# generator; confirm against ConvGenerator in models_conv.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])


# Downloads MNIST into ./data on first run.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
| |
# Generator and critic ("discriminator") from the project-local models_conv.
generator = ConvGenerator(latent_dim=latent_dim)
discriminator = ConvDiscriminator()


# RMSprop for both networks (momentum-free optimizer used with weight-clipped WGAN).
g_optimizer = optim.RMSprop(generator.parameters(), lr=lr)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=lr)


# Move models to GPU when available.  nn.Module.to() modifies parameters
# in place, so the optimizers created above still track the same tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)


print(f'Starting training on {device}...')
|
|
| |
# ---------------------------------------------------------------------------
# WGAN training loop (weight-clipping variant).
#
# Per batch: one critic update followed by weight clipping; one generator
# update every `n_critic` batches.  The critic maximizes
# E[D(real)] - E[D(fake)], so its minimized loss is the negation of that.
# ---------------------------------------------------------------------------

# Fixed evaluation noise so the TensorBoard sample images are directly
# comparable from epoch to epoch.
fixed_z = torch.randn(16, latent_dim, device=device)

# Defensive init: the logging branch below reads g_loss, which is otherwise
# only defined once a generator step has run (safe only because n_critic
# happens to divide 100 in the original).
g_loss = torch.tensor(0.0)

for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)

        # ---- Critic update -------------------------------------------------
        d_optimizer.zero_grad()

        z = torch.randn(real_imgs.size(0), latent_dim, device=device)
        # detach(): the critic step must not backpropagate into the generator.
        fake_imgs = generator(z).detach()

        # Wasserstein critic loss: -(E[D(real)] - E[D(fake)]).
        d_loss = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
        d_loss.backward()
        d_optimizer.step()

        # Weight clipping enforces the Lipschitz constraint (plain WGAN,
        # not the gradient-penalty variant).
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # ---- Generator update: every n_critic batches ----------------------
        if i % n_critic == 0:
            g_optimizer.zero_grad()

            # Fresh noise — do not reuse the batch the critic was just
            # trained on, keeping the G and D updates decorrelated.
            z = torch.randn(real_imgs.size(0), latent_dim, device=device)
            gen_imgs = generator(z)

            # Generator maximizes E[D(fake)]  =>  minimize -E[D(fake)].
            g_loss = -torch.mean(discriminator(gen_imgs))
            g_loss.backward()
            g_optimizer.step()

        # ---- Logging -------------------------------------------------------
        if i % 100 == 0:
            global_step = epoch * len(dataloader) + i
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')
            writer.add_scalar('D_loss', d_loss.item(), global_step)
            writer.add_scalar('G_loss', g_loss.item(), global_step)

    # ---- Periodic checkpoint (models + optimizer state) --------------------
    if epoch % 10 == 0:
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
        }, f'checkpoints/wgan_checkpoint_epoch_{epoch}.pt')

    # ---- Per-epoch sample images to TensorBoard ----------------------------
    # eval(): if ConvGenerator contains BatchNorm/Dropout, train-mode forward
    # passes update BN running statistics even under no_grad(); switching to
    # eval prevents this 16-sample pass from polluting training statistics.
    generator.eval()
    with torch.no_grad():
        gen_imgs = generator(fixed_z)
        for j, img in enumerate(gen_imgs):
            writer.add_image(f'generated_image_{j}', img, epoch)
    generator.train()
|
|
print('Training finished!')
writer.close()  # flush any pending TensorBoard events to disk