| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torchvision import datasets, transforms |
| from torch.utils.data import DataLoader |
| from models import Generator, Discriminator |
| import os |
|
|
| |
# ---- Hyperparameters ----
latent_dim = 100  # dimensionality of the generator's input noise vector
batch_size = 64   # samples per training batch
n_epochs = 200    # total passes over the training set
lr = 0.0002       # Adam learning rate (standard DCGAN value)
beta1 = 0.5       # Adam beta1; DCGAN uses 0.5 instead of the default 0.9
|
|
| |
| os.makedirs('images', exist_ok=True) |
|
|
| |
# ---- Data pipeline ----
# ToTensor maps pixels to [0, 1]; Normalize([0.5], [0.5]) then rescales
# single-channel MNIST to [-1, 1] — assumes the generator emits values in
# that range (e.g. a tanh output) — TODO confirm in models.py.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Downloads MNIST to ./data on first run; shuffling each epoch.
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
| |
# ---- Models, loss, optimizers ----
generator = Generator(latent_dim=latent_dim)
discriminator = Discriminator()

# BCELoss expects probabilities in (0, 1), i.e. the discriminator must end
# in a sigmoid — NOTE(review): verify in models.py (otherwise use
# BCEWithLogitsLoss).
adversarial_loss = nn.BCELoss()

# Separate Adam optimizers for each network; beta1=0.5 per DCGAN practice.
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
|
|
| |
# ---- Device placement ----
# nn.Module.to() moves parameters in place, so the optimizers created above
# still reference the (now moved) parameters — this ordering is safe.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)  # no-op for BCELoss (no buffers), but harmless

print(f'Starting training on {device}...')
|
|
| |
# ---- Training loop ----
for epoch in range(n_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        # The last batch of an epoch may be smaller than the configured
        # batch_size, so size labels/noise from the actual batch. Distinct
        # name avoids shadowing the module-level batch_size constant.
        n_samples = real_imgs.shape[0]

        # Target labels: 1 for real, 0 for fake. Allocate directly on the
        # target device instead of CPU-allocating and copying with .to().
        valid = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        real_imgs = real_imgs.to(device)

        # --- Train the generator ---
        g_optimizer.zero_grad()

        z = torch.randn(n_samples, latent_dim, device=device)
        gen_imgs = generator(z)

        # The generator wants the discriminator to classify its output as
        # real, hence the `valid` targets.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        g_optimizer.step()

        # --- Train the discriminator ---
        d_optimizer.zero_grad()

        # detach() stops discriminator gradients from flowing back into
        # the generator on the fake branch.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        d_optimizer.step()

        if i % 100 == 0:
            print(f'[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] '
                  f'[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]')

    # Every 10 epochs, save a fixed batch of 16 generated samples as raw
    # tensors (convert with torchvision.utils.save_image if image files
    # are wanted instead).
    if epoch % 10 == 0:
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=device)
            gen_imgs = generator(z)
            torch.save(gen_imgs, f'images/epoch_{epoch}.pt')

print('Training finished!')